package cn.doitedu.ml

import org.apache.log4j.{Level, Logger}
import org.apache.spark.ml.evaluation.RegressionEvaluator
import org.apache.spark.ml.linalg
import org.apache.spark.ml.linalg.Vectors
import org.apache.spark.ml.regression.LinearRegression
import org.apache.spark.sql.SparkSession
import org.apache.spark.sql.expressions.UserDefinedFunction

import scala.collection.mutable

object LinearPrice {

  /**
   * House-price prediction demo:
   *  1. loads a training CSV with columns (area, floor, price),
   *  2. fits a regularized linear regression on features [area, floor],
   *  3. predicts on a separate test CSV,
   *  4. reports RMSE of the predictions.
   */
  def main(args: Array[String]): Unit = {
    Logger.getLogger("org.apache").setLevel(Level.WARN)

    val spark = SparkSession
      .builder()
      .appName("房价预测")
      .master("local")
      .getOrCreate()
    import org.apache.spark.sql.functions._
    import spark.implicits._

    // Vector is an interface with two implementations: DenseVector and SparseVector.
    // This UDF packs an array column into a dense ML Vector, as required by Spark ML.
    val arr2Vec: UserDefinedFunction = udf((arr: mutable.WrappedArray[Double]) => {
      val vector: linalg.Vector = Vectors.dense(arr.toArray)
      vector
    })

    // Build the (features, price) frame expected by the estimator/evaluator.
    // Shared by train and test paths so both are guaranteed to use the same schema.
    def toFeatureDF(src: org.apache.spark.sql.DataFrame): org.apache.spark.sql.DataFrame =
      src.select(arr2Vec(array('area, 'floor)) as "features", 'price)

    // Training data: columns area, floor, price
    val df = spark.read.option("header", true).option("inferSchema", true).csv("userprofile/data/linear/sample")
    val vecDF = toFeatureDF(df)

    vecDF.show(100, false)

    // Construct the estimator.
    val linearRegression = new LinearRegression()
      .setRegParam(0.1)   // regularization parameter, guards against overfitting
      .setLabelCol("price")
      .setFeaturesCol("features")

    // Fit the model on the training set.
    val model = linearRegression.fit(vecDF)

    // Load the held-out test data.
    val test = spark.read.option("header", true).option("inferSchema", true).csv("userprofile/data/linear/test")
    // BUGFIX: previously vectorized `df` (the training data) here, so the model
    // was evaluated on its own training set and the test CSV was never used.
    val testVecDF = toFeatureDF(test)

    // Predict on the test data with the trained model.
    val result = model.transform(testVecDF)

    result.show(100, false)

    // Evaluate prediction quality. Supported metrics include:
    //   rmse - root mean squared error (used here)
    //   mse  - mean squared error
    //   r2   - coefficient of determination
    //   mae  - mean absolute error
    val regressionEvaluator = new RegressionEvaluator()
      .setPredictionCol("prediction")
      .setLabelCol("price")
      .setMetricName("rmse")

    val rmse: Double = regressionEvaluator.evaluate(result)

    println(rmse)

    spark.close()
  }

}
